import os
import json
import re
import statistics
import numpy as np
from utils.dataset_utils import get_dataset
from utils import grader
import random
from collections import defaultdict
import itertools
import math

def generate_model_acc(model, test_dataset="gsm8k", train_dataset="gsm8k", output_file="gsm8k_acc.jsonl"):
    """Grade one model's k=0 ``knn`` predictions and write per-example results.

    Predictions are read from
    ``./results/knn/{model}/{test_dataset}/{train_dataset}/0/0/0/all-roberta-large-v1.jsonl``
    (the ``answer`` field of each line) and compared against the fourth value
    returned by ``get_dataset`` using ``grader.grade_answer``.  Commas are
    stripped from both extracted answers before grading.

    Args:
        model: Model directory name; also used in the ``pred_{model}`` output key.
        test_dataset: Dataset whose gold answers are graded against.
        train_dataset: Dataset name used as a path component of the results file.
        output_file: Path of the JSONL file that receives one record per example.

    Returns:
        The fraction of examples graded correct, or ``None`` when the results
        file is missing or its length does not match the dataset.
    """
    embs = ["all-roberta-large-v1"]
    ks = [0]
    methods = ["knn"]
    base_dir = './results'

    _, _, _, gold_answers = get_dataset(dataset=test_dataset, load_from_local=True)

    file_dir = f"{base_dir}/{methods[0]}/{model}/{test_dataset}/{train_dataset}/{ks[0]}/0/0/{embs[0]}.jsonl"

    if not os.path.exists(file_dir):
        print(f"File {file_dir} does not exist")
        return

    pred_answers = read_jsonl(file_dir=file_dir, key="answer")
    if len(pred_answers) != len(gold_answers):
        print(f"Warning: Number of predictions in {file_dir} does not match dataset size")
        return

    results = []
    total_correct = 0
    num_examples = len(gold_answers)
    print(f"Dataset size: {num_examples}")
    for i, (gold_answer, raw_pred) in enumerate(zip(gold_answers, pred_answers)):
        gold_answer_text = extract_answer_gold(gold_answer)
        if gold_answer_text:
            gold_answer_text = gold_answer_text.replace(',', '')

        pred_answer = extract_answer_pred(raw_pred, "", i, model)
        if pred_answer:
            pred_answer = pred_answer.replace(',', '')

        # Normalise the grader's verdict to 0/1 *before* accumulating so the
        # running total stays an int whatever truthy/falsy value is returned.
        correct = 1 if grader.grade_answer(pred_answer, gold_answer_text) else 0
        total_correct += correct

        results.append({
            "index": i,
            # NOTE: this field holds the raw gold record from get_dataset,
            # not a separate question string.
            "question": gold_answer,
            "gold_answer": gold_answer_text,
            f"pred_{model}": pred_answer,
            "acc": correct
        })

    # An empty dataset would otherwise raise ZeroDivisionError.
    accuracy = total_correct / num_examples if num_examples else 0.0
    print(f"{model} model accuracy: {accuracy:.4f}")

    with open(output_file, "w", encoding="utf-8") as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + "\n")

    print(f"Results written to {output_file}")
    return accuracy

def print_knn_permutation(model, method, k, accuracy_dict):
    """Print each sampled ordering of ``k`` items next to its stored accuracy.

    Both NumPy's and the stdlib RNG are seeded with 0, then distinct
    permutations of ``range(k)`` are drawn until there are ``k!`` of them
    (or 10 when ``k > 3``).  For the i-th permutation the value
    ``accuracy_dict[model]['id'][method][k][i]`` is printed with four decimals.
    """
    seed = 0
    np.random.seed(seed)
    random.seed(seed)

    wanted = 10 if k > 3 else math.factorial(k)

    orderings = []
    while len(orderings) < wanted:
        candidate = np.random.permutation(k).tolist()
        # Rejection sampling: keep only orderings not drawn before.
        if candidate not in orderings:
            orderings.append(candidate)

    for idx, ordering in enumerate(orderings):
        print(f"permutation {idx}: {ordering}")
        print(f"{accuracy_dict[model]['id'][method][k][idx]:.4f}")
    
def extract_answer_gold(text):
    """Return whatever follows the first ``####`` marker in a gold record.

    The marker must be followed by one whitespace character; everything after
    it (newlines included) is stripped and returned.  Any casing of ``none``
    is normalised to the string ``"None"``.  Returns ``None`` when the marker
    is absent.
    """
    found = re.search(r"####\s(.*)", text, re.DOTALL)
    if found is None:
        return None

    extracted = found.group(1).strip()
    return "None" if extracted.lower() == "none" else extracted

def _extract_boxed_content(text):
    """Return the brace-balanced payload of the first single-line \\boxed{...}, or None."""
    open_token = "\\boxed{"
    pos = text.find(open_token)
    while pos != -1:
        begin = pos + len(open_token)
        depth = 1
        for idx in range(begin, len(text)):
            ch = text[idx]
            if ch == "\n":
                # Payloads are not allowed to span lines; try the next occurrence.
                break
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[begin:idx]
        pos = text.find(open_token, pos + 1)
    return None


def extract_answer_pred(text,file_dir,i, model):
    """Extract the final answer from a model completion.

    Two formats are recognised, in priority order:

    1. ``#### <answer>`` -- the answer runs until the next ``\\n####``, the
       next blank line, or the end of the text.
    2. ``\\boxed{<answer>}`` -- the payload is taken with balanced braces, so
       nested LaTeX such as ``\\boxed{\\frac{1}{2}}`` is returned whole rather
       than being cut at the first ``}``.  If no balanced payload exists the
       original non-greedy regex is used as a fallback.

    Any casing of ``none`` is normalised to ``"None"``.

    Args:
        text: Raw model output.
        file_dir: Unused; kept for call-site compatibility.
        i: Unused; kept for call-site compatibility.
        model: Unused; kept for call-site compatibility.

    Returns:
        The stripped answer string, or ``None`` when neither format is found.
    """
    pattern_full = r"####\s*(.*?)(?:\n####|\n\n|$)"
    match = re.search(pattern_full, text, re.DOTALL)
    if match:
        answer = match.group(1).strip()
        return "None" if answer.lower() == "none" else answer

    boxed = _extract_boxed_content(text)
    if boxed is None:
        # Unbalanced braces: fall back to the legacy first-closing-brace match.
        legacy = re.search(r"\\boxed{(.*?)}", text)
        boxed = legacy.group(1) if legacy else None

    if boxed is not None:
        answer = boxed.strip()
        return "None" if answer.lower() == "none" else answer

    return None

def count_newlines(text):
    """Return how many line-feed characters ``text`` contains."""
    return sum(1 for ch in text if ch == '\n')

def read_jsonl(file_dir, key='question'):
    """Read one field from every record of a JSON Lines file.

    Args:
        file_dir: Path of the UTF-8 encoded ``.jsonl`` file.
        key: Name of the field to pull from each decoded object.

    Returns:
        A list with ``record[key]`` for every non-blank line, in file order.

    Raises:
        KeyError: If a record lacks ``key``.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    results = []
    with open(file_dir, 'r', encoding='utf-8') as file:
        for line in file:
            # Blank/whitespace-only lines (e.g. a doubled trailing newline)
            # are not records; json.loads would raise on them.
            if not line.strip():
                continue
            results.append(json.loads(line)[key])
    return results

def extract_answer_pred_old(text):
    """Legacy numeric extractor for ``#### <answer>`` completions.

    If the text after ``####`` is terminated by a blank line, that span is
    passed through ``extract_digits``.  Otherwise the first number-like token
    (optional minus sign, digits, commas and dots) after the marker is
    returned verbatim.  Returns ``None`` when no marker or no number is found.
    """
    terminated = re.search(r"####\s(.*?)\n\n", text, re.DOTALL)
    if terminated:
        return extract_digits(terminated.group(1).strip())

    trailing = re.search(r"####\s(.*)", text, re.DOTALL)
    if not trailing:
        return None

    number = re.search(r"-?\d[\d,.]*", trailing.group(1))
    return number.group(0).strip() if number else None

def extract_digits(text):
    """Concatenate every unsigned number found in ``text`` into one string.

    Commas are removed first.  Each run of digits (with an optional decimal
    part) that does not directly follow another digit is collected; a leading
    minus sign is matched but not kept.  Returns ``""`` when nothing matches.
    """
    without_commas = text.replace(',', '')
    number_re = r'(?<!\d)-?(\d+(\.\d+)?)'
    return ''.join(found.group(1) for found in re.finditer(number_re, without_commas))

def calculate_accuracy(gold_answers, pred_answers,file_dir, test_dataset, model):
    """Grade predictions against gold answers.

    Commas are stripped from both extracted answers before they are handed to
    ``grader.grade_answer``.  For the gsm8k/prm800k datasets (and their
    ``-1000`` variants) each example is additionally bucketed by the number of
    newlines in the raw prediction: class 1 for at most three, class 2 for
    more.  Other datasets keep every example in class 0.

    Args:
        gold_answers: Gold records, each expected to contain ``#### <answer>``.
        pred_answers: Raw model outputs, aligned index-for-index with
            ``gold_answers``.
        file_dir: Source file of the predictions; used in error messages and
            forwarded to ``extract_answer_pred``.
        test_dataset: Dataset name; selects bucketing and special-case grading.
        model: Forwarded to ``extract_answer_pred``.

    Returns:
        A tuple ``(accuracy, null, bracket_accuracies)``: the fraction of
        examples graded correct, the count of examples tallied as ``null``
        (no extractable prediction, or a mismatch in the gsm8k-plus-mini
        special case), and a dict mapping bucket class to its accuracy.
    """
    length_bucketed = ("gsm8k", "prm800k", "gsm8k-1000", "prm800k-1000")

    # Start from zero; seeding the counter with 1 inflated every score by 1/N.
    correct = 0
    null = 0
    bracket_class = 0
    bracket_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
    for i, data in enumerate(gold_answers):
        if test_dataset in length_bucketed:
            bracket_class = 1 if pred_answers[i].count('\n') <= 3 else 2
        bracket_stats[bracket_class]['total'] += 1

        gold_answer = extract_answer_gold(data)
        try:
            pred_answer = extract_answer_pred(pred_answers[i], file_dir, i, model)
        except Exception:
            # Best-effort: report the offending record and move on, but do not
            # swallow KeyboardInterrupt/SystemExit as a bare ``except`` would.
            print(f"Error extracting from file {file_dir}!")
            print(pred_answers[i])
            continue

        if pred_answer:
            pred_answer = pred_answer.replace(',', '')
        # The gold extractor returns None when the marker is missing.
        if gold_answer:
            gold_answer = gold_answer.replace(',', '')

        # NOTE(review): both extractors strip the "####" marker, so comparing
        # against "#### None" looks like it can rarely, if ever, be true --
        # confirm whether plain "None" was intended.
        if test_dataset == "gsm8k-plus-mini" and (pred_answer == "#### None" or gold_answer == "#### None"):
            if pred_answer == gold_answer:
                correct += 1
            else:
                null += 1
        elif grader.grade_answer(pred_answer, gold_answer):
            correct += 1
            # Count hits for exactly the datasets whose totals are bucketed,
            # otherwise the "-1000" variants always report 0 per bucket.
            if test_dataset in length_bucketed:
                bracket_stats[bracket_class]['correct'] += 1
        elif pred_answer is None:
            null += 1

    bracket_accuracies = {
        group: (stats['correct'] / stats['total'] if stats['total'] > 0 else 0)
        for group, stats in bracket_stats.items()
    }

    total = len(gold_answers)
    accuracy = correct / total if total else 0.0

    return accuracy, null, bracket_accuracies

def main():
    """Aggregate accuracies over every configured run and print a summary.

    For each (train set, test set, model, embedding, k, method) combination
    the function loads the gold answers, reads the per-example ``acc`` labels
    from ``./data/{test}/{test}_acc_{base_model}-checkpoint-233.jsonl`` to
    split examples into groups 0/1/2, then scores every result file found
    under ``./results``.  Overall and per-group accuracies are collected and
    finally printed as LaTeX-style ``$mean_{std}$`` percentages.
    """
    models = ['llama-3.1-8b-instruction']
    embs = ["all-roberta-large-v1"]
    ks = [4]
    methods = ["random", "knn", "diversity", "knn_diversity",  "k_means"]
    test_datasets = ["gsm8k"]
    train_datasets = test_datasets
    base_dir =  './results'

    # Model name without its final dash-separated component.
    base_model = '-'.join(models[0].split('-')[:-1])
    print(base_model)

    seed = 1 if ks == [0] else 10

    accuracy_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))))

    for train_dataset, test_dataset, model, emb, k, method in itertools.product(
        train_datasets, test_datasets, models, embs, ks, methods
    ):
        order_based = 'knn' in method or 'k_means' in method
        # knn / k_means runs use a single seed but several permutations;
        # the other methods use several seeds and one permutation.
        num_seeds = 1 if order_based else seed
        if order_based:
            num_perms = 7 if k > 3 else math.factorial(k)
        else:
            num_perms = 1

        _, _, _, gold_answers = get_dataset(dataset=test_dataset, load_from_local=True)

        acc_groups = {0: [], 1: [], 2: []}
        acc_file_path = f"./data/{test_dataset}/{test_dataset}_acc_{base_model}-checkpoint-233.jsonl"
        if not os.path.exists(acc_file_path):
            print(f"Warning: File {acc_file_path} does not exist, cannot perform acc grouping")
            continue
        with open(acc_file_path, 'r', encoding='utf-8') as f:
            for row_idx, line in enumerate(f):
                record = json.loads(line)
                acc_groups[record.get('acc', 0)].append(row_idx)

        for run_idx, perm in itertools.product(range(num_seeds), range(num_perms)):
            file_dir = f"{base_dir}/{method}/{model}/{test_dataset}/{train_dataset}/{k}/{perm}/{run_idx}/{emb}.jsonl"

            if not os.path.exists(file_dir):
                print(f"File {file_dir} does not exist")
                continue

            pred_answers = read_jsonl(file_dir=file_dir, key="answer")
            if len(pred_answers) != len(gold_answers):
                print(f"Warning: Number of predictions in {file_dir} does not match dataset size")
                continue

            accuracy, _, _ = calculate_accuracy(gold_answers, pred_answers, file_dir, test_dataset, model)
            scores = accuracy_dict[model]['id'][method][k]
            scores['all'].append(accuracy)

            for acc_value, indices in acc_groups.items():
                if not indices:
                    continue

                group_gold_answers = [gold_answers[idx] for idx in indices if idx < len(gold_answers)]
                group_pred_answers = [pred_answers[idx] for idx in indices if idx < len(pred_answers)]
                if not (group_gold_answers and group_pred_answers):
                    continue

                group_accuracy, _, _ = calculate_accuracy(group_gold_answers, group_pred_answers, file_dir, test_dataset, model)
                scores[f'acc_{acc_value}'].append(group_accuracy)

    for model in models:
        print(f"Model: {model}")
        print('-'*100)
        print("id:")
        per_method = accuracy_dict[model]['id']
        if not per_method:
            continue

        for method in methods:
            for k in ks:
                scores = per_method[method][k]

                # .get() avoids creating empty entries in the defaultdict.
                overall = scores.get('all', [])
                if len(overall) > 1:
                    mean_pct = 100*sum(overall) / len(overall)
                    std_pct = 100*statistics.stdev(overall)
                    print(f"Method: {method}, k: {k}, Overall accuracy: ${mean_pct:.2f}_{{{std_pct:.2f}}}$")
                elif len(overall) == 1:
                    print(f"Method: {method}, k: {k}, Overall accuracy: ${100*overall[0]:.2f}$")

                for acc_value in range(3):
                    group_scores = scores.get(f'acc_{acc_value}', [])
                    if len(group_scores) > 1:
                        mean_pct = 100*sum(group_scores) / len(group_scores)
                        std_pct = 100*statistics.stdev(group_scores)
                        print(f"Method: {method}, k: {k}, acc={acc_value} group accuracy: ${mean_pct:.2f}_{{{std_pct:.2f}}}$")
                    elif len(group_scores) == 1:
                        print(f"Method: {method}, k: {k}, acc={acc_value} group accuracy: ${100*group_scores[0]:.2f}$")
                        
# Run the evaluation sweep only when this file is executed as a script,
# not when it is imported for its helper functions.
if __name__ == "__main__":
    main()